Simple Non-Linear Setting

Non-Linear Model with Binary Treatment and Continuous Covariates

Authors
Affiliations

Nolan Cole

Department of Biostatistics, University of Washington

Lars van der Laan

Department of Statistics, University of Washington

Marco Carone

Department of Biostatistics, University of Washington

Department of Statistics, University of Washington

Published

February 22, 2026

Simulation Details


Data Generating Process

\begin{align*} W_1 &\sim \mathrm{Unif}(0,1), \\ W_2 &\sim \mathrm{Unif}(-1,1), \\ W_3 &\sim \mathrm{Unif}(-5,5), \\ W &= (W_1, W_2, W_3)^\top, \\ V &= W_1, \\ \varepsilon &\sim \mathrm{N}(0,\sigma^2), \\ \pi_0(W) &= P(A=1 \mid W) = \mathrm{logit}^{-1}\!\big(\gamma_0 + \gamma_1 W_1 + \gamma_2 W_2 + \gamma_3 W_3\big), \\ \tau_0(W) &= \tfrac{1}{2}\sin(2\pi W_1) + \tfrac{1}{2}\cos(\pi W_2), \\ \mu_{00}(W) &= -2 + 0.5\,W_1 - 0.25\,W_2 + 0.1\,W_3, \\ \mu_{01}(W) &= \mu_{00}(W) + \tau_0(W), \\ Y &= \mu_{0A}(W) + \varepsilon, \\ \tau(W) &= \mathbb{E}[Y \mid A=1, W] - \mathbb{E}[Y \mid A=0, W] = \tau_0(W), \\ \bar{\tau}_0(v) &= \mathbb{E}[\tau_0(W) \mid V=v] = \tfrac{1}{2}\sin(2\pi v). \end{align*}


Simulation Parameters

| Parameter | Value |
|---|---|
| \gamma_0 | 0 |
| \gamma_1 | -0.75 |
| \gamma_2 | 0.5 |
| \gamma_3 | -0.25 |
| \sigma | 0.2 |
| n_{\text{vals}} (sample sizes) | {250, 500, 1000, 2500, 10000} |
| t_{\text{subset}} (evaluation thresholds) | {-0.6, -0.5, -0.25, 0, 0.25, 0.5, 0.6} |
| Support bounds of \bar{\tau}_0(V) | [-0.5, 0.5] |
| \text{ngrid} | 2500 |
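The sketch below draws a single simulated dataset from the data-generating process above, using the parameter values in this table. It assumes nothing beyond base R. The function name `simulate_dgp()` and its arguments are hypothetical and are not part of the original simulation code.

```r
# Hypothetical helper: draw n i.i.d. observations from the DGP above
simulate_dgp <- function(n, gamma = c(0, -0.75, 0.5, -0.25), sigma = 0.2) {
  W1 <- runif(n, 0, 1)
  W2 <- runif(n, -1, 1)
  W3 <- runif(n, -5, 5)
  # Logistic propensity score pi_0(W)
  pi0 <- plogis(gamma[1] + gamma[2] * W1 + gamma[3] * W2 + gamma[4] * W3)
  A <- rbinom(n, 1, pi0)
  # CATE tau_0(W) and control-arm regression mu_00(W)
  tau0 <- 0.5 * sin(2 * pi * W1) + 0.5 * cos(pi * W2)
  mu00 <- -2 + 0.5 * W1 - 0.25 * W2 + 0.1 * W3
  # Y = mu_{0A}(W) + epsilon with epsilon ~ N(0, sigma^2)
  Y <- mu00 + A * tau0 + rnorm(n, 0, sigma)
  data.frame(W1, W2, W3, V = W1, A, Y, pi0, tau0, mu00,
             bar_tau0 = 0.5 * sin(2 * pi * W1))
}

set.seed(2025)
dat <- simulate_dgp(1000)
head(dat)
```

The returned columns `pi0`, `tau0`, `mu00` and `bar_tau0` hold the true nuisance values. This makes it easy to check estimators against the truth.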

Statistical Parameters


CATE

\tau_0(w) = E_0[Y | A=1, W=w] - E_0[Y | A=0, W=w]


V-specific CATE

\bar{\tau}_0(v) = E_0[ \tau_0(W) | V=v]
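The closed form \bar{\tau}_0(v) = \tfrac{1}{2}\sin(2\pi v) in the data-generating process follows because W_2 is independent of V = W_1 and symmetric about zero, so the cosine term averages out. The same calculation gives the conditional variance that the appendix code hard-codes as `var0_tau0_V`:

\begin{align*} \mathbb{E}[\cos(\pi W_2)] &= \tfrac{1}{2}\int_{-1}^{1}\cos(\pi w)\,dw = \tfrac{1}{2}\Big[\tfrac{\sin(\pi w)}{\pi}\Big]_{-1}^{1} = 0, \\ \bar{\tau}_0(v) &= \tfrac{1}{2}\sin(2\pi v) + \tfrac{1}{2}\,\mathbb{E}[\cos(\pi W_2)] = \tfrac{1}{2}\sin(2\pi v), \\ \mathrm{Var}_0(\tau_0(W)\mid V) &= \tfrac{1}{4}\,\mathbb{E}[\cos^2(\pi W_2)] = \tfrac{1}{4}\cdot\tfrac{1}{2} = \tfrac{1}{8}. \end{align*}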

V-specific CATE CDF

\overline{\theta}_0(t) = E_0[1\{\bar{\tau}_0(V) \leq t\}]
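Under this DGP, 2\bar{\tau}_0(V) = \sin(2\pi V) with V \sim \mathrm{Unif}(0,1) follows the arcsine law on [-1,1]. For |t| \le 1/2 this gives \overline{\theta}_0(t) = 1/2 + \arcsin(2t)/\pi, with density 2/\{\pi\sqrt{1-4t^2}\}. These are the expressions used for `bar_theta0_t_vals` and `f_0bar_tau0_t_vals` in the appendix. The sketch below uses the hypothetical helper names `bar_theta0_closed()` and `bar_tau0_density()`. It checks the CDF against a Monte Carlo average and the density against a numerical derivative of the CDF.

```r
# Closed-form CDF of bar_tau0(V) = 0.5 * sin(2 * pi * V), V ~ Unif(0, 1).
# Clamping 2t to [-1, 1] also returns 0 below -1/2 and 1 above 1/2.
bar_theta0_closed <- function(t) {
  0.5 + asin(pmin(pmax(2 * t, -1), 1)) / pi
}

# Closed-form density; it diverges as |t| -> 1/2 (arcsine law)
bar_tau0_density <- function(t) {
  ifelse(abs(t) < 0.5, 2 / (pi * sqrt(pmax(1 - 4 * t^2, 0))), 0)
}

# Monte Carlo check of the CDF at a few thresholds
set.seed(1)
bar_tau0_draws <- 0.5 * sin(2 * pi * runif(1e5))
t_grid <- c(-0.6, -0.25, 0, 0.25, 0.6)
mc_cdf <- sapply(t_grid, function(t) mean(bar_tau0_draws <= t))
cbind(t = t_grid, closed = bar_theta0_closed(t_grid), monte_carlo = mc_cdf)
```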

V-specific CATE Primitive CDF

\overline{\Psi}_0(t) = E_0[1\{\bar{\tau}_0(V) \leq t\} \{t-\bar{\tau}_0(V)\}]
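Since (t - x)_+ = \int_{-\infty}^{t} 1\{x \le s\}\,ds, the primitive is the integral of the CDF: \overline{\Psi}_0(t) = \int_{-\infty}^{t} \overline{\theta}_0(s)\,ds. Under this DGP the integral has a closed form:

- for |t| \le 1/2, \overline{\Psi}_0(t) = t/2 + \{t\arcsin(2t) + \tfrac{1}{2}\sqrt{1-4t^2}\}/\pi;
- for t < -1/2, \overline{\Psi}_0(t) = 0;
- for t > 1/2, \overline{\Psi}_0(t) = t, because \mathbb{E}[\bar{\tau}_0(V)] = 0.

The appendix computes `bar_Psi_t_P0` by Monte Carlo instead. This sketch, with the hypothetical name `bar_Psi0_closed()`, compares the two.

```r
# Closed-form primitive: bar_Psi_0(t) = E[(t - bar_tau0(V))_+]
bar_Psi0_closed <- function(t) {
  s <- pmin(pmax(t, -0.5), 0.5)  # clamp to the support [-1/2, 1/2]
  inside <- s / 2 + (s * asin(2 * s) + 0.5 * sqrt(1 - 4 * s^2)) / pi
  ifelse(t > 0.5, t, inside)     # below -1/2 the clamped formula already gives 0
}

# Monte Carlo check at a few thresholds
set.seed(1)
bar_tau0_draws <- 0.5 * sin(2 * pi * runif(1e5))
t_grid <- c(-0.6, -0.25, 0, 0.25, 0.6)
mc_psi <- sapply(t_grid, function(t) mean((t - bar_tau0_draws) * (bar_tau0_draws <= t)))
cbind(t = t_grid, closed = bar_Psi0_closed(t_grid), monte_carlo = mc_psi)
```

The closed form is continuous at t = \pm 1/2, and \overline{\Psi}_0(0) = 1/(2\pi).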


Remainder Terms for Primitive

\begin{align*} \overline{\Psi}_t(P) - \overline{\Psi}_t(P_0) + P_0 D(P) & = E_{0} \bigg[ 1\{\bar{\tau}_P \leq t < \bar{\tau}_0\} (\bar{\tau}_P - \bar{\tau}_0)\bigg] \\ & \quad + E_{0} \bigg[ 1\{\bar{\tau}_P < t < \bar{\tau}_0\} (t-\bar{\tau}_P)\bigg]\\ & \quad - E_{0} \bigg[ 1\{\bar{\tau}_0 < t < \bar{\tau}_P\} (t - \bar{\tau}_0) \bigg] \\ & \quad - \frac{1}{2} E_0\left[1\{\bar{\tau}_P = t\} \left( \bar{\tau}_P-\bar{\tau}_0 \right) \right] \\ & \quad + \bar{R}_1(P_0, P) \\ & = \bar{R}, \end{align*}
where \bar{\tau}_P and \bar{\tau}_0 are shorthand for \bar{\tau}_P(V) and \bar{\tau}_0(V), and we have defined \bar{R}_1(P_0, P) := - E_{0} \bigg[ \left(1\{\bar{\tau}_P(V)<t\} + \frac{1\{\bar{\tau}_P(V) = t\}}{2} \right) \bigg\{ \Big( \frac{\pi_0(W)}{\pi_P(W)} - 1\Big) \big(\mu_0(1,W) - \mu_P(1,W)\big) - \Big(\frac{1 - \pi_0(W)}{1 - \pi_P(W)} - 1\Big) \big(\mu_0(0,W) - \mu_P(0,W)\big) \bigg\} \bigg].
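For reference, the influence function D(P) = \bar{D}_t(P) in the remainder is written out below exactly as it is coded in the appendix (`SD_bar_D_t0`). Here \phi_P is the AIPW pseudo-outcome. Taking the conditional mean of \phi_P given W produces the nuisance product terms collected in \bar{R}_1:

\begin{align*} \phi_P(O) &= \frac{A - \pi_P(W)}{\pi_P(W)\{1 - \pi_P(W)\}}\{Y - \mu_P(A,W)\} + \tau_P(W), \\ \bar{D}_t(P)(O) &= -\Big(1\{\bar{\tau}_P(V) < t\} + \tfrac{1}{2}1\{\bar{\tau}_P(V) = t\}\Big)\{\phi_P(O) - \bar{\tau}_P(V)\} + 1\{\bar{\tau}_P(V) \leq t\}\{t - \bar{\tau}_P(V)\} - \overline{\Psi}_t(P), \\ E_0[\phi_P(O) \mid W] &= \tau_0(W) + \Big(\frac{\pi_0(W)}{\pi_P(W)} - 1\Big)\{\mu_0(1,W) - \mu_P(1,W)\} - \Big(\frac{1 - \pi_0(W)}{1 - \pi_P(W)} - 1\Big)\{\mu_0(0,W) - \mu_P(0,W)\}. \end{align*}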


Nuisance Parameters

| Nuisance parameter | Algorithm | Purpose / rationale |
|---|---|---|
| \mu_P(A,W) (outcome regression) | Stratified HAL (Lrnr_stratified + Lrnr_hal9001, max_degree=2, smoothness_orders=1), stratified on A | Flexible nonparametric regression with fast convergence rates under a bounded sectional variation norm assumption. Stratifying on A fits a separate regression in each arm, so treatment-covariate interactions need not be specified. |
| \pi_P(W) (propensity score) | Logistic regression (GLM with logit link), model A \sim W_1 + W_2 + W_3 | Correctly specified under this DGP, so it is stable and converges at the parametric rate. |
| \tau_P(W) (CATE) | Plug-in estimator \tau_n(W)=\mu_{n,1}(W)-\mu_{n,0}(W) | Standard T-learner built from the outcome regression. |
| \bar{\tau}_P(V) (V-specific CATE) | DR-learner: kernel regression (safe_fk_regression) of the pseudo-outcome \phi_n on V | Exploits the one-dimensional structure V=W_1: a univariate kernel smoother suffices. |
| \mathrm{Var}_P(Y\mid A,W) | Stratified HAL fit to squared residuals (Y-\mu_n(A,W))^2 | Flexible variance estimate needed for the Chernoff scaling constants. |
| c_P(W) | \displaystyle c_n(W)=\frac{\widehat{\mathrm{Var}}(Y\mid A=1,W)(1-\pi_n(W))+\widehat{\mathrm{Var}}(Y\mid A=0,W)\pi_n(W)}{\pi_n(W)(1-\pi_n(W))} | Variance component of the efficient influence function. |
| \mathbb{E}[c_P(W)\mid V] | Kernel regression of c_n(W) on V (safe_fk_regression) | One-dimensional conditional expectation, estimated stably by univariate smoothing. |
| \mathrm{Var}(\tau_P(W)\mid V) | Kernel regression of (\tau_n(W)-\bar{\tau}_n(V))^2 on V (safe_fk_regression) | Component of the \bar{c}_n(V) decomposition. |
| \bar{c}_P(V) | \bar{c}_n(V)=\mathbb{E}[c_n(W)\mid V]+\mathrm{Var}(\tau_n(W)\mid V) | Variance term in the Chernoff scaling constant. |
| \mathbb{E}[\bar{c}_P(V)\mid \bar{\tau}_P(V)] | HAL regression (Lrnr_hal9001, max_degree=1, smoothness_orders=0) of \bar{c}_n(V) on \bar{\tau}_n(V) | The target is smooth apart from a jump discontinuity (at t = 0 under this DGP). Zero-order HAL fits piecewise-constant functions and can represent the jump. |
| f_{\bar{\tau}_P(V)} (density) | Kernel density estimation via FKSUM::fk_density | Required for the Chernoff normalization constant. |
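The DR-learner step for \bar{\tau}_P(V) can be sketched as follows. The sketch makes two illustrative substitutions:

- It plugs in the true propensity score and outcome regressions in place of the fitted logistic and HAL models.
- It replaces `safe_fk_regression` (a wrapper not shown in this report) with base R's Nadaraya-Watson smoother `stats::ksmooth()` at a hand-picked bandwidth of 0.1.

Both are assumptions, so this is not the simulation's actual estimator.

```r
set.seed(2025)
n  <- 5000
W1 <- runif(n, 0, 1)
W2 <- runif(n, -1, 1)
W3 <- runif(n, -5, 5)

# True nuisances stand in for the fitted ones (illustration only)
pi_n  <- plogis(0 - 0.75 * W1 + 0.5 * W2 - 0.25 * W3)  # propensity score
tau_W <- 0.5 * sin(2 * pi * W1) + 0.5 * cos(pi * W2)   # CATE tau(W)
mu0_n <- -2 + 0.5 * W1 - 0.25 * W2 + 0.1 * W3          # E[Y | A = 0, W]
mu1_n <- mu0_n + tau_W                                 # E[Y | A = 1, W]

A <- rbinom(n, 1, pi_n)
Y <- ifelse(A == 1, mu1_n, mu0_n) + rnorm(n, 0, 0.2)

# AIPW pseudo-outcome: E[phi_n | W] = tau(W) when the nuisances are correct
muA_n <- ifelse(A == 1, mu1_n, mu0_n)
phi_n <- (A - pi_n) / (pi_n * (1 - pi_n)) * (Y - muA_n) + (mu1_n - mu0_n)

# Second stage: Nadaraya-Watson regression of phi_n on V = W1
v_grid <- seq(0.1, 0.9, by = 0.05)
fit <- ksmooth(W1, phi_n, kernel = "normal", bandwidth = 0.1, x.points = v_grid)
bar_tau_n <- fit$y

round(cbind(v = fit$x, estimate = bar_tau_n, truth = 0.5 * sin(2 * pi * fit$x)), 3)
```

The evaluation grid stays away from V = 0 and V = 1. A Nadaraya-Watson smoother has boundary bias there, which a local-linear smoother (such as the `type = 'loc-lin'` fits used in the appendix) reduces.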

Appendix

Code

# setup
knitr::opts_chunk$set(echo = FALSE, warning = FALSE, message = FALSE, cache = FALSE)

# load libraries
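if (!requireNamespace("remotes", quietly = TRUE)) install.packages("remotes")
if (!requireNamespace("sl3", quietly = TRUE)) remotes::install_github("tlverse/sl3", force = TRUE)
if (!requireNamespace("hal9001", quietly = TRUE)) remotes::install_github("tlverse/hal9001", force = TRUE)
if (!requireNamespace("pacman", quietly = TRUE)) install.packages("pacman")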
library(pacman)
pacman::p_load(
  tidyverse, kableExtra, knitr,
  grf, remotes, gbm,
  data.table, origami,
  ranger, xgboost, randomForest,
  gt, latex2exp, rsample,
  hal9001, sl3, plotly
)

set.seed(2025)
# results_psi <- readRDS("/Users/nolan/Library/CloudStorage/OneDrive-UW/0_Research/Carone/primitive_cate/reports/chapters/2_simple_example/6.5_vcate_wave_psi.rds")
# results_theta <- readRDS("/Users/nolan/Library/CloudStorage/OneDrive-UW/0_Research/Carone/primitive_cate/reports/chapters/2_simple_example/6.5_vcate_wave_theta.rds")
results_theta <- results_psi <- readRDS("/Users/nolan/Library/CloudStorage/OneDrive-UW/0_Research/Carone/primitive_cate/reports/chapters/2_simple_example/2.1_compare_grids.rds")

fast_qchern <- readRDS("/Users/nolan/Library/CloudStorage/OneDrive-UW/0_Research/Carone/primitive_cate/reports/chapters/2_simple_example/fast_qchern.rds")

qchernoff_fast <- function(p) {
  if (any(p < 0.001 | p > 0.999, na.rm = TRUE)) {
    stop("qchernoff_fast only defined for p in [0.001, 0.999]")
  }
  fast_qchern(p)
}
attr(qchernoff_fast, "description") <- "Approximate Chernoff quantile: qchernoff_fast(p) ~ ChernoffDist::qchernoff(p)"

lower <- -0.5
upper <- 0.5
t_subset <- c(lower-0.1, seq(lower, upper, 0.25), upper+0.1) |> sort()

t_theta_vals <- t_subset # results_theta$t |> unique()
# qs_theta <- quantile(t_theta_vals, probs = c(0.15, 0.5, 0.85))
t_theta_subset <- t_theta_vals # sapply(qs_theta, \(q) t_theta_vals[which.min(abs(t_theta_vals - q))])

t_psi_vals <- t_theta_vals # results_psi$t |> unique()
# qs_psi <- quantile(t_psi_vals, probs = c(0.15, 0.5, 0.85))
t_psi_subset <- t_psi_vals # sapply(qs_psi, \(q) t_psi_vals[which.min(abs(t_psi_vals - q))])
##########
## DGM  ##
##########

# propensity model
gamma0 <- 0
gamma1 <- -0.75
gamma2 <-  0.5
gamma3 <- -0.25

# outcome noise SD
sigma <- 0.2

param_tvals <- c(seq(lower - 0.1, upper + 0.1, 0.01), t_theta_vals) |>
  unique() |>
  sort()
param_n <- 100000 # for parameter computation

#############################
## Data Generating Process ##
#############################

## generate data ##
W1 <- runif(param_n, 0, 1)
W2 <- runif(param_n, -1, 1)
W3 <- runif(param_n, -5, 5)
W  <- cbind("W1" = W1, "W2" = W2, "W3" = W3)
V  <- W1

# Propensity score
pi0 <- plogis(gamma0 + gamma1 * W1 + gamma2 * W2 + gamma3 * W3)
A   <- rbinom(param_n, 1, pi0)

# TRUE CATE and baseline, written inline:
# tau_0(W) = 0.5 * sin(2*pi*W1) + 0.5 * cos(pi*W2)
tau0 <- 0.5 * sin(2 * pi * W1) + 0.5 * cos(pi * W2)

# mu_00(W) = -2 + 0.5*W1 - 0.25*W2 + 0.1*W3
mu_00 <- -2 + 0.5 * W1 - 0.25 * W2 + 0.1 * W3     # E[Y | A=0, W]
mu_01 <- mu_00 + tau0                             # E[Y | A=1, W]
mu_0  <- mu_00 + A * tau0                         # E[Y | A, W]

# TRUE V-specific CATE
bar_tau0 <- 0.5 * sin(2 * pi * V)

# Observed outcome
Y <- mu_0 + rnorm(param_n, mean = 0, sd = sigma)
YAW <- data.frame(Y = Y, A = A, W)
Y1W <- data.frame(Y = Y, A = 1, W)
Y0W <- data.frame(Y = Y, A = 0, W)

############################
## Statistical Parameters ##
############################

# Var_0(Y \mid A, W)
# YAW$resid2 <- (Y - mu_0)^2
# var0Y1W_fit <- lm(resid2 ~ W1 + W2 + W3, data = YAW |> filter(A==1))
# var0Y0W_fit <- lm(resid2 ~ W1 + W2 + W3, data = YAW |> filter(A==0))
# var0Y1W <- predict(var0Y1W_fit, newdata = data.frame(W1, W2, W3))
# var0Y0W <- predict(var0Y0W_fit, newdata = data.frame(W1, W2, W3))
# Under homoskedasticity
var0Y1W <- sigma^2
var0Y0W <- sigma^2

# c_0(W) := \frac{Var(Y | A=1,W)(1-\pi_0(W)) + Var(Y | A=0,W)\pi_0(W)}{\pi_0(W)(1-\pi_0(W))}
c0W <- (var0Y1W*(1-pi0) + var0Y0W*pi0) / (pi0 * (1-pi0))

# E_0[ c_0(W) | V]
# Derivation note:
# 1. Under homoskedasticity (Var = sigma^2), c_0(W) simplifies to
#    c_0(W) = sigma^2 / {pi_0 (1 - pi_0)} = sigma^2 * [ 1/pi_0 + 1/(1 - pi_0) ].
#    Writing pi_0 = expit(eta), eta = gamma0 + gamma1*W1 + gamma2*W2 + gamma3*W3,
#    we have 1/pi_0 = 1 + exp(-eta) and 1/(1 - pi_0) = 1 + exp(eta), so
#    c_0(W) = sigma^2 * (2 + 2 * cosh(eta)).
# 2. To find E[c_0(W) | V = v], integrate over W2 ~ U(-1, 1) and W3 ~ U(-5, 5).
#    With a = gamma0 + gamma1*v, X = gamma2*W2, Y = gamma3*W3 (independent, symmetric):
#    cosh(a + X + Y) = cosh(a) cosh(X + Y) + sinh(a) sinh(X + Y); the sinh term has
#    mean zero, and E[cosh(X + Y)] = E[cosh X] * E[cosh Y].
# 3. For X ~ U(-b, b), E[cosh X] = sinh(b) / b, which gives the sinh(x)/x factors in K.
E_c0_V_closed <- function(v) {
  
  sinh_over_x <- function(x) ifelse(abs(x) < 1e-12, 1, sinh(x) / x)
  
  a <- gamma0 + gamma1 * v
  K <- sinh_over_x(gamma2) * sinh_over_x(5 * gamma3)
  sigma^2 * (2 + 2 * K * cosh(a))
}
E_c0_V_closed <- Vectorize(E_c0_V_closed, vectorize.args = "v")
E_c0_V <- E_c0_V_closed(V) # FKSUM::fk_regression(V, c0W, h = "cv", type = 'loc-lin')
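
# Sanity check (illustrative addition, not part of the original pipeline):
# compare the closed form for E_0[c_0(W) | V = v] with a deterministic
# midpoint-rule average over W2 ~ U(-1, 1) and W3 ~ U(-5, 5) at v = 0.3.
# Parameter values are copied locally from the simulation table, so the check
# neither depends on nor modifies the globals above (and uses no RNG).
# The object name check_E_c0_V is hypothetical.
check_E_c0_V <- local({
  g   <- c(0, -0.75, 0.5, -0.25)  # gamma0, ..., gamma3
  s2  <- 0.2^2                    # sigma^2
  v   <- 0.3
  shx <- function(x) ifelse(abs(x) < 1e-12, 1, sinh(x) / x)
  closed <- s2 * (2 + 2 * shx(g[3]) * shx(5 * g[4]) * cosh(g[1] + g[2] * v))
  m   <- 1000
  w2  <- -1 + (seq_len(m) - 0.5) * 2 / m   # midpoints of U(-1, 1)
  w3  <- -5 + (seq_len(m) - 0.5) * 10 / m  # midpoints of U(-5, 5)
  eta <- outer(g[1] + g[2] * v + g[3] * w2, g[4] * w3, "+")
  p   <- plogis(eta)
  # homoskedastic c_0(W) = sigma^2 / {pi_0 (1 - pi_0)}, averaged over the grid
  c(closed = closed, quadrature = mean(s2 / (p * (1 - p))))
})
check_E_c0_V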

# Var(\tau_0(W) | V)
# Under this DGM, Var(\tau_0(W) | V) = Var(0.5 * cos(pi W_2)) = 0.25 * E[cos^2(pi W_2)] = 0.25 * 0.5,
# because W_2 is independent of V and E[cos(pi W_2)] = 0
var0_tau0_V <- 0.125 # FKSUM::fk_regression(V, (tau0 - bar_tau0)^2, ngrid = 5000)

# \bar{c}_0 := E_0[ c_0(W) | V] + Var(\tau_0(W) | V)
barc0V <- E_c0_V + var0_tau0_V

# E[\bar{c}_0(V) \mid \bar{\tau}_0(V)]
E_barc0_bar_tau0 <- FKSUM::fk_regression(bar_tau0, barc0V, h = "cv", type = 'loc-lin')
E_barc0_given_bartau0 <- function(t) {
  # 1. Map bartau0 back to the two possible roots of V in [0, 1]
  # bartau0 = 0.5 * sin(2 * pi * V) => sin(2 * pi * V) = 2t
  # Clamp 2t to [-1, 1] with pmin/pmax so asin() never returns NaN from floating-point overshoot
  theta <- asin(pmin(pmax(2 * t, -1), 1))
  
  v1 <- (theta / (2 * pi)) %% 1
  v2 <- ((pi - theta) / (2 * pi)) %% 1
  
  # 2. Compute bar_c0(V) for both roots
  # Uses the globals E_c0_V_closed() and var0_tau0_V defined above
  c_v1 <- E_c0_V_closed(v1) + var0_tau0_V
  c_v2 <- E_c0_V_closed(v2) + var0_tau0_V
  
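  # 3. V ~ U(0, 1) and |d bar_tau0 / dV| = pi * |cos(2 * pi * V)| is the same at both
  #    roots, so given bar_tau0(V) = t each root has conditional probability 1/2 and the
  #    conditional expectation is the arithmetic mean of the two branches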
  return(0.5 * (c_v1 + c_v2))
}
E_barc0_given_bartau0 <- Vectorize(E_barc0_given_bartau0, vectorize.args = "t")
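
# Sanity check (illustrative addition, not part of the original pipeline):
# the two-root average used in E_barc0_given_bartau0() should agree with a
# direct average of bar_c0(V) over a fine uniform grid of V values whose
# bar_tau0(V) lies in a narrow window around t0 = 0.2. Parameter values are
# copied locally (no globals, no RNG); the name check_two_root is hypothetical.
check_two_root <- local({
  g   <- c(0, -0.75, 0.5, -0.25)
  s2  <- 0.2^2
  shx <- function(x) ifelse(abs(x) < 1e-12, 1, sinh(x) / x)
  # bar_c0(v) = E[c_0(W) | V = v] + Var(tau_0(W) | V), with Var(tau_0(W) | V) = 1/8
  bar_c <- function(v) {
    s2 * (2 + 2 * shx(g[3]) * shx(5 * g[4]) * cosh(g[1] + g[2] * v)) + 0.125
  }
  t0 <- 0.2
  th <- asin(2 * t0)
  two_root <- 0.5 * (bar_c((th / (2 * pi)) %% 1) + bar_c(((pi - th) / (2 * pi)) %% 1))
  v   <- (seq_len(1e6) - 0.5) / 1e6   # fine grid standing in for V ~ U(0, 1)
  sel <- abs(0.5 * sin(2 * pi * v) - t0) < 0.005
  c(two_root = two_root, window_average = mean(bar_c(v[sel])))
})
check_two_root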
# E[\bar{c}_0(V) | \bar{\tau}_0(V) = t] evaluated on param_tvals (closed form via the two roots)
E_barc0_bar_tau0_t_vals <- E_barc0_given_bartau0(param_tvals)

# True V-specific CATE density at t over param_tvals
f_0bar_tau0_t_vals <- ifelse(
  abs(param_tvals) < 0.5,
  2 / (pi * sqrt(1 - 4 * param_tvals^2)),
  0
)

# True V-specific CATE CDF at t over param_tvals
bar_theta0_t_vals <- ifelse(
  param_tvals < -0.5,
  0,
  ifelse(
    param_tvals > 0.5,
    1,
    0.5 + (1 / pi) * asin(2 * param_tvals)
  )
)

parameter_df <- tibble(
  t = param_tvals
) %>%
  mutate(
    # \bar{\theta}_0 (CDF of V-specific CATE)
    bar_theta0 = bar_theta0_t_vals,
    
    # \bar{\Psi}_t(P_0) (V-specific CATE primitive)
    bar_Psi_t_P0 = map_dbl(t, ~ mean((.x - bar_tau0) * (bar_tau0 <= .x))),
    
    # SD of the true influence function \bar{D}_t(P_0) (Monte Carlo over the simulated data)
    SD_bar_D_t0 = map_dbl(
      t,
      ~ sd(
        -((bar_tau0 < .x) + 0.5 * (bar_tau0 == .x)) *
          ((A - pi0) / (pi0 * (1 - pi0)) * (Y - mu_0) + tau0 - bar_tau0) +
          (bar_tau0 <= .x) * (.x - bar_tau0) -
          mean((.x - bar_tau0) * (bar_tau0 <= .x))
      )
    ),
    
    # V-specific CATE density at t (already evaluated on param_tvals)
    f_0bar_tau0_t = f_0bar_tau0_t_vals,
    
    # E[\bar{c}_0(W) | \bar{\tau}_0 = t] on param_tvals
    E_barc0_bar_tau0_t = E_barc0_bar_tau0_t_vals,
    
    # \bar{\kappa}_0(t)
    bar_kappa0 = E_barc0_bar_tau0_t * f_0bar_tau0_t,
    
    # Chernoff scaling constant: \bar{\rho}_0(t)
    bar_rho0 = (4 * f_0bar_tau0_t * bar_kappa0)^(1/3)
  )


# Grid over W1 in [0,1] and W2 in [-1,1]
W1_grid <- seq(0, 1, length.out = 51)
W2_grid <- seq(-1, 1, length.out = 51)

# Create matrix of tau_0(W1, W2) values
# tau_0(W) = 0.5 * sin(2*pi*W1) + 0.5 * cos(pi*W2)
tau_mat <- outer(
  W1_grid,
  W2_grid,
  function(w1, w2) 0.5 * sin(2 * pi * w1) + 0.5 * cos(pi * w2)
)

## 3D surface plot with sign-based coloring

plot_ly() |>
  add_surface(
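    x = ~W1_grid,
    y = ~W2_grid,
    z = ~t(tau_mat),    # plotly: rows of z index y (W2), columns index x (W1)
    cmin = -1,          # tau_0 ranges over [-1, 1]; a symmetric color range
    cmax = 1,           # keeps zero at the gray midpoint
    colorscale = list(
      c(0.0, "red"),      # adverse effect (red)
      c(0.5, "#F0F0F0"),  # around zero (light gray)
      c(1.0, "blue")      # positive effect (blue)
    )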
  ) |>
  layout(
    scene = list(
      xaxis = list(title = "W1"),
      yaxis = list(title = "W2"),
      zaxis = list(title = "tau_0(W)")  # plotly does not render latex2exp expressions
    ),
    title = "True CATE surface: tau_0(W) = 0.5 sin(2*pi*W1) + 0.5 cos(pi*W2)"
  )
## Heatmap with the same sign-based color palette

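plot_ly(
  x = ~W1_grid,
  y = ~W2_grid,
  z = ~t(tau_mat)  # plotly: rows of z index y (W2), columns index x (W1)
) |>
  add_heatmap(
    zmin = -1, zmax = 1,  # symmetric range so zero maps to the gray midpoint
    colorscale = list(
      c(0.0, "red"),      # negative (red)
      c(0.5, "#F0F0F0"),  # around zero (light gray)
      c(1.0, "blue")      # positive (blue)
    )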
  ) |>
  layout(
    xaxis = list(title = "W1"),
    yaxis = list(title = "W2"),
    title = "True CATE heatmap: tau_0(W)"
  )
data.frame(
  V = V,
  bar_tau0 = bar_tau0
) %>%
  ggplot(aes(x = V, y = bar_tau0)) +
  geom_line() +
  geom_hline(yintercept = 0, color = "red", linetype = "dashed") +
  labs(
    x = expression(V == W[1]),
    y = expression(bar(tau)[0](V)),
    title = expression("True V-specific CATE " ~ bar(tau)[0](V))
  ) +
  theme_bw() +
  theme(aspect.ratio = 1)



ggplot(parameter_df, aes(x=t, y = bar_theta0)) +
  geom_line() +
  theme_bw() +
  labs(
    x = latex2exp::TeX("$t$"),
    y = latex2exp::TeX("$\\bar{\\theta}_P(t)$"),
    title = "CDF of V-specific CATE",
    subtitle = latex2exp::TeX("$P(\\bar{\\tau}_P(V) \\leq t)$")
  ) +
  theme(aspect.ratio = 1)

all_metric <- results_theta %>%
  select(t, n, grid, bar_theta_n, bar_theta_os) %>%
  left_join(parameter_df %>% select(t, bar_theta0), by = "t") %>%
  pivot_longer(
    cols = matches("^bar_theta_(n|os|tmle)$"),  # avoids accidentally grabbing bar_theta0
    names_to = "estimator",
    values_to = "bar_theta_est"
  ) %>%
  group_by(t, n, grid, estimator) %>%
  summarise(
    bias = mean(bar_theta_est - bar_theta0, na.rm = TRUE),
    var  = var(bar_theta_est, na.rm = TRUE),
    mse  = mean((bar_theta_est - bar_theta0)^2, na.rm = TRUE),
    .groups = "drop"
  ) %>%
  mutate(
    sc_bias2 = n^(2/3) * bias^2,
    sc_var   = n^(2/3) * var,
    sc_mse   = n^(2/3) * mse
  ) %>%
  select(-bias, -var, -mse) %>%
  filter(
    t %in% t_theta_subset
  ) %>%
  pivot_longer(
    cols = c(sc_bias2, sc_var, sc_mse),
    names_to = "metric",
    values_to = "value"
  ) %>%
  mutate(
    estimator = recode(
      estimator,
      bar_theta_n    = "Plug-in",
      bar_theta_os   = "One-step",
      bar_theta_tmle = "TMLE"
    ),
    metric_label = recode(
      metric,
      sc_bias2 = "n^{2/3}*Bias^2",
      sc_var   = "n^{2/3}*Variance",
      sc_mse   = "n^{2/3}*MSE"
    ),
    metric_label = factor(
      metric_label,
      levels = c("n^{2/3}*Bias^2", "n^{2/3}*Variance", "n^{2/3}*MSE")
    )
  )

# Helper: scaled bias^2 / variance / MSE panels for one grid type
metric_plot_for_grid <- function(dat, grid_val, title) {
  dat %>%
    filter(grid == grid_val) %>%
    ggplot(
      aes(x = n, y = value, color = estimator, shape = estimator, linetype = estimator)
    ) +
    geom_line(linewidth = 1) +
    geom_point(size = 2) +
    geom_hline(yintercept = 0, color = "black") +
    facet_grid(metric_label ~ t, scales = "free_y", labeller = label_parsed) +
    theme_bw() +
    labs(
      x = "Sample Size (n)",
      y = NULL,
      color = "Estimator",
      linetype = "Estimator",
      shape = "Estimator",
      title = title
    ) +
    scale_color_manual(values = c(
      "Plug-in"  = "plum",
      "One-step" = "#0072B2",
      "TMLE"     = "#D55E00"
    )) +
    scale_linetype_manual(values = c(
      "Plug-in"  = "dotted",
      "One-step" = "solid",
      "TMLE"     = "longdash"
    )) +
    scale_shape_manual(values = c(
      "Plug-in"  = 16,
      "One-step" = 17,
      "TMLE"     = 15
    )) +
    theme(
      aspect.ratio = 1,
      axis.text.x = element_text(angle = 45, hjust = 1),
      legend.position = "bottom"
    )
}

# LEFT
metric_plot_for_grid(all_metric, "left", "Left grid")

# FINE
metric_plot_for_grid(all_metric, "fine", "Fine grid")

# INTERVAL
metric_plot_for_grid(all_metric, "interval", "Interval grid")

results_theta_coverage <- results_theta %>%
  left_join(
    parameter_df %>% select(t, bar_theta0),
    by = c("t")
  ) %>%
  select(
    t, n, grid,
    contains("bar_theta"),
    # bar_theta0,
    # bar_theta_n, bar_theta_os, #bar_theta_tmle,
    bar_rhon, bar_rhon_plugin
  ) %>%
  # Long format for the three estimators' point estimates
  pivot_longer(
    cols         = contains("bar_theta_"),
    names_to     = "estimator",
    names_prefix = "bar_theta_",
    values_to    = "theta_hat"
  ) %>%
  mutate(
    
    # plug-in will have normal dist, OS will have chernoff
    se = if_else(estimator == "n", bar_rhon_plugin/n^(1/2), bar_rhon / n^(1/3)),
    
    # 97.5% Chernoff quantile with the factor 2 built in
    q975        = if_else(estimator == "n", 1.96, fast_qchern(0.975)),
    
    # Wald-type Chernoff CIs
    ci_lower    = theta_hat - q975 * se,
    ci_upper    = theta_hat + q975 * se,
    
    # Coverage indicator
    covered     = (ci_lower <= bar_theta0 & bar_theta0 <= ci_upper),
    
    # Nice estimator labels for plotting
    estimator   = dplyr::recode(
      estimator,
      "n"    = "Plug-in",
      "os"   = "One-step",
      "tmle" = "TMLE"
    )
  ) %>%
  group_by(n, t, grid, estimator) %>%
  summarise(
    coverage = mean(covered, na.rm = TRUE),
    .groups  = "drop"
  )

results_theta_coverage %>%
  filter(
    t %in% t_theta_subset
  ) %>%
  ggplot(
    .,
    aes(
      x        = n,
      y        = coverage,
      color    = estimator,
      linetype = estimator,
      shape    = estimator
    )
  ) +
  geom_line(linewidth = 1) +
  geom_point(size = 2) +
  geom_hline(yintercept = 0.95, linetype = "dashed", color = "black") +
  facet_grid(grid ~ t) +
  theme_bw() +
  labs(
    x        = "Sample Size",
    y        = "Empirical 95% Coverage",
    color    = "Estimator",
    linetype = "Estimator",
    shape    = "Estimator",
    title    = latex2exp::TeX("Empirical Coverage of 95\\% Confidence Intervals for $\\bar{\\theta}_{t,0}$")
  ) +
  scale_color_manual(values = c(
    "Plug-in" = "plum",
    "One-step" = "#0072B2",
    "TMLE"    = "#D55E00"
  )) +
  scale_linetype_manual(values = c(
    "Plug-in" = "dotted",
    "One-step" = "solid",
    "TMLE"    = "longdash"
  )) +
  scale_shape_manual(values = c(
    "Plug-in" = 16,  # filled circle
    "One-step" = 17, # filled triangle
    "TMLE"    = 15   # filled square
  )) +
  theme(
    aspect.ratio = 1,
    axis.text.x  = element_text(angle = 45, hjust = 1),
    legend.position = "bottom"
  )

results_theta_coverage %>%
  ggplot(aes(
    x = t, y = coverage,
    color    = estimator,
    linetype = estimator,
    shape    = estimator
  )) +
  geom_line(linewidth = 1) +
  geom_hline(yintercept = 0.95) +
  facet_grid(~ n) +
  scale_x_continuous(
    limits = c(min(t_theta_vals), max(t_theta_vals)),
    breaks = seq(min(t_theta_vals), max(t_theta_vals)+0.2, by = 0.2)
  ) +
  labs(
    x        = "t",
    y        = "Empirical 95% Coverage",
    color    = "Estimator",
    linetype = "Estimator",
    shape    = "Estimator",
    title    = latex2exp::TeX("Empirical Coverage of 95\\% Confidence Intervals for $\\bar{\\theta}_{t,0}$")
  ) +
  scale_color_manual(values = c(
    "Plug-in" = "plum",
    "One-step" = "#0072B2",
    "TMLE"    = "#D55E00"
  )) +
  scale_linetype_manual(values = c(
    "Plug-in" = "dotted",
    "One-step" = "solid",
    "TMLE"    = "longdash"
  )) +
  scale_shape_manual(values = c(
    "Plug-in" = 16,  # filled circle
    "One-step" = 17, # filled triangle
    "TMLE"    = 15   # filled square
  )) +
  theme_bw() +
  theme(
    aspect.ratio = 1,
    axis.text.x  = element_text(angle = 45, hjust = 1),
    legend.position = "bottom"
  )
results_theta_ciwidth <- results_theta %>%
  left_join(
    parameter_df %>% select(t, bar_theta0),
    by = c("t")
  ) %>%
  select(
    t, n, grid,
    bar_theta0,
    bar_theta_n, bar_theta_os, # bar_theta_tmle,
    bar_rhon, bar_rhon_plugin
  ) %>%
  pivot_longer(
    cols         = contains("bar_theta_"),
    names_to     = "estimator",
    names_prefix = "bar_theta_",
    values_to    = "theta_hat"
  ) %>%
  mutate(
    # plug-in will have normal dist, OS will have chernoff
    se = if_else(estimator == "n", bar_rhon_plugin/n^(1/2), bar_rhon / n^(1/3)),
    
    # 97.5% Chernoff quantile with the factor 2 built in
    q975        = if_else(estimator == "n", 1.96, fast_qchern(0.975)),
    
    # Wald-type Chernoff CIs
    ci_lower    = theta_hat - q975 * se,
    ci_upper    = theta_hat + q975 * se,
    
    # CI width
    ci_width = ci_upper - ci_lower,
    
    # nice labels
    estimator = dplyr::recode(
      estimator,
      "n"    = "Plug-in",
      "os"   = "One-step",
      "tmle" = "TMLE"
    )
  ) %>%
  group_by(n, t, grid, estimator) %>%
  summarise(
    mean_ci_width   = mean(ci_width, na.rm = TRUE),
    median_ci_width = median(ci_width, na.rm = TRUE),
    .groups = "drop"
  )

results_theta_ciwidth %>%
  filter(t %in% t_theta_subset) %>%
  ggplot(
    aes(
      x        = n,
      y        = mean_ci_width,   # swap to median_ci_width if you prefer
      color    = estimator,
      linetype = estimator,
      shape    = estimator
    )
  ) +
  geom_line(linewidth = 1) +
  geom_point(size = 2) +
  geom_hline(yintercept = 0, color = "black") +
  facet_grid(grid ~ t) +
  theme_bw() +
  labs(
    x        = "Sample Size",
    y        = "Mean 95% CI Width",
    color    = "Estimator",
    linetype = "Estimator",
    shape    = "Estimator",
    title    = latex2exp::TeX("Mean Width of 95\\% Confidence Intervals for $\\bar{\\theta}_{t,0}$")
  ) +
  scale_color_manual(values = c(
    "Plug-in"  = "plum",
    "One-step" = "#0072B2",
    "TMLE"     = "#D55E00"
  )) +
  scale_linetype_manual(values = c(
    "Plug-in"  = "dotted",
    "One-step" = "solid",
    "TMLE"     = "longdash"
  )) +
  scale_shape_manual(values = c(
    "Plug-in"  = 16,
    "One-step" = 17,
    "TMLE"     = 15
  )) +
  theme(
    aspect.ratio = 1,
    axis.text.x  = element_text(angle = 45, hjust = 1),
    legend.position = "bottom"
  )

# Compute n^(1/3)-standardized quantities
results_theta_standardized <- results_theta %>%
  left_join(
    parameter_df %>% select(t, bar_theta0),
    by = c("t")
  ) %>%
  mutate(
    z_Pn   = n^(1/3) * (bar_theta_n   - bar_theta0) / bar_rhon_plugin,
    z_Pn_gcm   = n^(1/2) * (bar_theta_gcm_n   - bar_theta0) / bar_rhon_plugin,
    z_os   = n^(1/3) * (bar_theta_os - bar_theta0) / bar_rhon,
    # z_tmle = n^(1/3) * (bar_theta_tmle - bar_theta0) / bar_rhon
  ) %>%
  select(t, n, grid, starts_with("z_")) %>%
  pivot_longer(
    cols = starts_with("z_"),
    names_to = "estimator",
    values_to = "z_value"
  ) %>%
  mutate(
    estimator = recode(
      estimator,
      "z_Pn"   = "Plug-in",
      "z_Pn_gcm" = "Plug-in+GCM",
      "z_os"   = "One-step",
      "z_tmle" = "TMLE"
    )
  )

# Standard Chernoff density reference curve
chern_ref <- tibble(
  x = seq(-4, 4, length.out = 400),
  density = ChernoffDist::dChern(seq(-4, 4, length.out = 400))
)


dist_plot_for_grid <- function(dat, grid_val, chern_ref) {
  dat %>%
    filter(grid == grid_val) %>%
    ggplot(aes(x = z_value, fill = n)) +
    geom_histogram(
      aes(y = after_stat(density)),
      position = "identity",
      alpha = 0.4,
      color = "black",
      bins = 40
    ) +
    geom_line(
      data = chern_ref,
      aes(x = x, y = density),
      color = "black",
      linewidth = 1,
      linetype = "solid",
      inherit.aes = FALSE
    ) +
    geom_vline(xintercept = 0, linetype = "dashed") +
    facet_grid(estimator ~ t, scales = "free") +
    theme_bw() +
    labs(
      x = latex2exp::TeX("$n^{1/3} (\\bar{\\theta}_{t,n} - \\bar{\\theta}_{t,0}) / \\rho_n$"),
      y = "Density",
      fill = "Sample size (n)",
      title = paste("Empirical Distributions with Standard Chernoff Overlay — grid:", grid_val),
      subtitle = latex2exp::TeX("Standardized Estimators of $\\bar{\\theta}_{t,0}$")
    ) +
    scale_fill_brewer(palette = "Set2") +
    theme(
      strip.text = element_text(size = 10),
      aspect.ratio = 1,
      legend.position = "bottom"
    )
}

# Make one plot per grid ---
grid_vals <- sort(unique(results_theta_standardized$grid))
names(grid_vals) <- as.character(grid_vals)

subset_dist <- results_theta_standardized %>%
  filter(
    n %in% 2500,
    t %in% c(-0.25, 0, 0.25)
  ) %>%
  mutate(n = factor(n))   # treat n as discrete so fill uses a discrete palette

dist_plots <- map(
  grid_vals,
  ~ dist_plot_for_grid(subset_dist, .x, chern_ref)
)
# Print them all
walk(dist_plots, print)
results_theta_qq <- results_theta %>%
  left_join(
    parameter_df %>% select(t, bar_theta0),
    by = c("t")
  ) %>%
  mutate(
    # Chernoff statistics:
    chernoff_stat_n    = n^(1/3) * (bar_theta_n - bar_theta0) / bar_rhon,    
    chernoff_stat_os   = n^(1/3) * (bar_theta_os - bar_theta0) / bar_rhon,   
    # chernoff_stat_tmle = n^(1/3) * (bar_theta_tmle - bar_theta0) / bar_rhon
  ) %>%
  select(t, n, grid, contains("chernoff_stat_")) %>%
  drop_na() %>%
  pivot_longer(
    contains("chernoff_stat_"), #c(chernoff_stat_n, chernoff_stat_os, chernoff_stat_tmle),
    names_to  = "estimator",
    values_to = "sample"
  ) %>%
  group_by(n, t, grid, estimator) %>%
  arrange(sample, .by_group = TRUE) %>%          # sort empirical quantiles
  mutate(
    prob_seq = ppoints(dplyr::n()),              # ppoints per (n, t, estimator)
    theor    = fast_qchern(prob_seq)             # standardized chernoff
  ) %>%
  ungroup()


# 2) Plot function for one grid value
qq_plot_for_grid <- function(dat, grid_val) {
  dat %>%
    filter(grid == grid_val) %>%
    ggplot(aes(x = theor, y = sample, color = estimator)) +
    geom_point(alpha = 0.5, size = 0.8) +
    geom_abline(intercept = 0, slope = 1, linetype = "dashed") +
    facet_grid(n ~ t, scales = "free_y") +
    labs(
      x = "Theoretical Quantiles",
      y = "Empirical Quantiles",
      title = paste("Chernoff QQ plots — grid:", grid_val)
    ) +
    scale_color_discrete(
      name   = "Estimator",
      breaks = c("chernoff_stat_n", "chernoff_stat_os", "chernoff_stat_tmle"),
      labels = c("Plug-in", "One-step", "TMLE")
    ) +
    theme_bw() +
    theme(
      aspect.ratio    = 1,
      legend.position = "bottom"
    )
}

# 3) Make one QQ plot per grid
grid_vals <- sort(unique(results_theta_qq$grid))
names(grid_vals) <- as.character(grid_vals)
qq_plots <- map(grid_vals, ~ qq_plot_for_grid(
  results_theta_qq %>%
    filter(
      n %in% c(250, 2500, 10000),
      t %in% t_theta_subset
    ),
  .x))

# Print them all
walk(qq_plots, print)


ggplot(parameter_df, aes(x=t, y = bar_Psi_t_P0)) +
  geom_line() +
  theme_bw() +
  labs(
    x = latex2exp::TeX("$t$"),
    y = latex2exp::TeX("$\\bar{\\Psi}_P(t)$"),
    title = "Primitive of the V-specific CATE CDF",
    subtitle = latex2exp::TeX("$E[1\\{\\bar{\\tau}_P(V) \\leq t\\} \\{t - \\bar{\\tau}_P(V)\\} ]$")
  ) +
  theme(aspect.ratio = 1)

all_metric_psi <- results_psi %>%
  select(t, n, grid, bar_Psi_Pn, bar_Psi_t_os) %>%
  left_join(parameter_df %>% select(t, bar_Psi_t_P0), by = "t") %>%
  pivot_longer(
    cols = matches("^bar_Psi_(Pn|t_os|t_tmle)$"),  # avoids accidentally grabbing bar_Psi_t_P0
    names_to = "estimator",
    values_to = "bar_psi_est"
  ) %>%
  group_by(t, n, grid, estimator) %>%
  summarise(
    bias = mean(bar_psi_est - bar_Psi_t_P0, na.rm = TRUE),
    var  = var(bar_psi_est, na.rm = TRUE),
    mse  = mean((bar_psi_est - bar_Psi_t_P0)^2, na.rm = TRUE),
    .groups = "drop"
  ) %>%
  mutate(
    sc_bias2 = n * bias^2,
    sc_var   = n * var,
    sc_mse   = n * mse
  ) %>%
  select(-bias, -var, -mse) %>%
  filter(
    t %in% t_psi_subset
  ) %>%
  pivot_longer(
    cols = c(sc_bias2, sc_var, sc_mse),
    names_to = "metric",
    values_to = "value"
  ) %>%
  mutate(
    estimator = recode(
      estimator,
      bar_Psi_Pn    = "Plug-in",
      bar_Psi_t_os   = "One-step",
      bar_Psi_t_tmle = "TMLE"
    ),
    metric_label = recode(
      metric,
      sc_bias2 = "n*Bias^2",
      sc_var   = "n*Variance",
      sc_mse   = "n*MSE"
    ),
    metric_label = factor(
      metric_label,
      levels = c("n*Bias^2", "n*Variance", "n*MSE")
    )
  )

# INTERVAL
all_metric_psi %>%
  filter(grid == "interval") %>%
  ggplot(
    aes(x = n, y = value, color = estimator, shape = estimator, linetype = estimator)
  ) +
  geom_line(linewidth = 1) +
  geom_point(size = 2) +
  geom_hline(yintercept = 0, color = "black") +
  facet_grid(metric_label ~ t, scales = "free_y", labeller = label_parsed) +
  theme_bw() +
  labs(
    x = "Sample Size (n)",
    y = NULL,
    color = "Estimator",
    linetype = "Estimator",
    shape = "Estimator",
    # title = latex2exp::TeX("$\\bar{\\Psi}_{t,0}$")
  ) +
  scale_color_manual(values = c(
    "Plug-in"  = "plum",
    "One-step" = "#0072B2",
    "TMLE"     = "#D55E00"
  )) +
  scale_linetype_manual(values = c(
    "Plug-in"  = "dotted",
    "One-step" = "solid",
    "TMLE"     = "longdash"
  )) +
  scale_shape_manual(values = c(
    "Plug-in"  = 16,
    "One-step" = 17,
    "TMLE"     = 15
  )) +
  theme(
    aspect.ratio = 1,
    axis.text.x = element_text(angle = 45, hjust = 1),
    legend.position = "bottom"
  )

# Compute empirical coverage for each estimator
results_psi_coverage <- results_psi %>%
  filter(grid == "interval") %>%
  left_join(
    parameter_df,
    by = c("t")
  ) %>%
  select(
    t, n, bar_Psi_t_P0,
    starts_with("bar_Psi_"),
    SD_bar_D_tn, SD_Psi_Pn
  ) %>%
  # Long format for the three estimators' point estimates
  pivot_longer(
    cols         = c(bar_Psi_Pn, bar_Psi_t_os), # bar_Psi_t_tmle),
    names_to     = "estimator",
    names_prefix = "bar_Psi_",
    values_to    = "Psi_hat"
  ) %>%
  mutate(
    # use the plug-in SD for plug-in CIs; after names_prefix the plug-in label is "Pn"
    se = if_else(estimator == "Pn", SD_Psi_Pn/n^(1/2), SD_bar_D_tn / n^(1/2)),
    
    # Wald-type CIs
    ci_lower    = Psi_hat   - 1.96 * se,
    ci_upper    = Psi_hat   + 1.96 * se,
    
    # Coverage indicator
    covered     = (ci_lower <= bar_Psi_t_P0 & bar_Psi_t_P0 <= ci_upper),
    
    # Nice estimator labels for plotting
    estimator   = dplyr::recode(
      estimator,
      "Pn"    = "Plug-in",
      "t_os"   = "One-step",
      "t_tmle" = "TMLE"
    )
  ) %>%
  group_by(n, t, estimator) %>%
  summarise(
    coverage = mean(covered, na.rm = TRUE),
    .groups  = "drop"
  )
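
# Sketch (hypothetical helper, base R only): empirical coverage is the share of
# Monte Carlo replicates whose Wald interval estimate +/- 1.96 * se contains the
# truth. Replicates are drawn from an idealised N(truth, sd^2 / n) sampling
# distribution. `se_scale` multiplies the correct standard error, so
# se_scale = 1 should give roughly 95% coverage, while se_scale = 0.5 (an SE half
# the true one) should give roughly 2 * pnorm(0.98) - 1, about 67%.
sketch_wald_coverage <- function(truth, sd, n, se_scale = 1, reps = 5000, seed = 1) {
  set.seed(seed)  # note: this resets the global RNG state
  true_se   <- sd / sqrt(n)
  estimates <- rnorm(reps, mean = truth, sd = true_se)
  se        <- se_scale * true_se
  covered   <- (estimates - 1.96 * se <= truth) & (truth <= estimates + 1.96 * se)
  mean(covered)
}

sketch_cov_correct  <- sketch_wald_coverage(truth = 0.16, sd = 0.3, n = 1000)
sketch_cov_small_se <- sketch_wald_coverage(truth = 0.16, sd = 0.3, n = 1000, se_scale = 0.5)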


results_psi_coverage %>%
  filter(
    t %in% t_psi_subset
  ) %>%
  ggplot(aes(
    x = n, y = coverage,
    color = estimator,
    shape = estimator,
    linetype = estimator
  )) +
  geom_line(linewidth = 0.9) +
  geom_point(size = 2) +
  geom_hline(yintercept = 0.95, linetype = "dashed", color = "black") +
  facet_grid(~ t, scales = "free_y") +
  theme_bw() +
  labs(
    x = "Sample Size (n)",
    y = "Empirical 95% Coverage",
    color = "Estimator",
    linetype = "Estimator",
    shape = "Estimator",
    title = latex2exp::TeX("Empirical Coverage of 95\\% Wald Confidence Intervals for $\\bar{\\Psi}_{t,0}$")
  ) +
  scale_color_manual(
    values = c(
      "Plug-in" = "plum",
      "One-step" = "#0072B2",
      "TMLE" = "#D55E00"
    )
  ) +
  scale_linetype_manual(
    values = c(
      "Plug-in" = "dotted",
      "One-step" = "solid",
      "TMLE" = "longdash"
    )
  ) +
  coord_cartesian(ylim = c(0, 1.0)) +
  theme(
    aspect.ratio = 1,
    axis.text.x = element_text(angle = 45, hjust = 1),
    legend.position = "bottom"
  )

results_psi_coverage %>%
  ggplot(aes(
    x = t, y = coverage,
    color    = estimator,
    linetype = estimator,
    shape    = estimator
  )) +
  geom_line(linewidth = 1) +
  geom_point(size = 2) +
  geom_hline(yintercept = 0.95) +
  facet_grid( ~ n, scales = "free_y") +
  theme_bw() +
  labs(
    x        = "t",
    y        = "Empirical 95% Coverage",
    color    = "Estimator",
    linetype = "Estimator",
    shape    = "Estimator",
    title    = latex2exp::TeX("Empirical Coverage of 95\\% Confidence Intervals for $\\bar{\\Psi}_{t,0}$")
  ) +
  scale_color_manual(values = c(
    "Plug-in" = "plum",
    "One-step" = "#0072B2",
    "TMLE"    = "#D55E00"
  )) +
  scale_linetype_manual(values = c(
    "Plug-in" = "dotted",
    "One-step" = "solid",
    "TMLE"    = "longdash"
  )) +
  scale_shape_manual(values = c(
    "Plug-in" = 16,  # filled circle
    "One-step" = 17, # filled triangle
    "TMLE"    = 15   # filled square
  )) +
  theme(
    aspect.ratio = 1,
    axis.text.x  = element_text(angle = 45, hjust = 1),
    legend.position = "bottom"
  )
results_psi_ciwidth <- results_psi %>%
  filter(grid == "interval") %>%   # match the coverage and standardized-error summaries
  left_join(
    parameter_df,
    by = c("t")
  ) %>%
  select(
    t, n, bar_Psi_t_P0,
    starts_with("bar_Psi_"),
    SD_bar_D_tn, SD_Psi_Pn
  ) %>%
  # Long format for the estimators' point estimates (TMLE currently excluded)
  pivot_longer(
    cols         = c(bar_Psi_Pn, bar_Psi_t_os), # bar_Psi_t_tmle),
    names_to     = "estimator",
    names_prefix = "bar_Psi_",
    values_to    = "Psi_hat"
  ) %>%
  mutate(
    # use plug-in SD for plug-in confidence intervals
    # names_prefix strips "bar_Psi_", so the plug-in rows are labelled "Pn"
    se = if_else(estimator == "Pn", SD_Psi_Pn / sqrt(n), SD_bar_D_tn / sqrt(n)),
    
    # Wald-type CIs
    ci_lower    = Psi_hat   - 1.96 * se,
    ci_upper    = Psi_hat   + 1.96 * se,
    
    # CI width
    ci_width = ci_upper - ci_lower,
    
    # Nice estimator labels for plotting
    estimator   = dplyr::recode(
      estimator,
      "Pn"    = "Plug-in",
      "t_os"   = "One-step",
      "t_tmle" = "TMLE"
    )
  ) %>%
  group_by(n, t, estimator) %>%
  summarise(
    mean_ci_width   = mean(ci_width, na.rm = TRUE),
    median_ci_width = median(ci_width, na.rm = TRUE),
    .groups = "drop"
  )
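
# Sketch (hypothetical helper): a Wald interval has width
# ci_upper - ci_lower = 2 * 1.96 * SD / sqrt(n), so with a stable SD estimate
# the width shrinks like n^{-1/2}. Quadrupling n halves it, and
# sqrt(n) * width / (2 * 1.96) recovers the SD that was plugged in.
sketch_wald_width <- function(sd, n, z = 1.96) 2 * z * sd / sqrt(n)

sketch_widths  <- sketch_wald_width(sd = 0.3, n = c(250, 1000, 10000))
sketch_sd_back <- sqrt(c(250, 1000, 10000)) * sketch_widths / (2 * 1.96)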

results_psi_ciwidth %>%
  filter(t %in% t_theta_subset) %>%
  ggplot(
    aes(
      x        = n,
      y        = mean_ci_width,   # swap to median_ci_width if you prefer
      color    = estimator,
      linetype = estimator,
      shape    = estimator
    )
  ) +
  geom_line(linewidth = 1) +
  geom_point(size = 2) +
  geom_hline(yintercept = 0, color = "black") +
  facet_grid( ~ t) +
  theme_bw() +
  labs(
    x        = "Sample Size",
    y        = "Mean 95% CI Width",
    color    = "Estimator",
    linetype = "Estimator",
    shape    = "Estimator",
    title    = latex2exp::TeX("Mean Width of 95\\% Confidence Intervals for $\\bar{\\Psi}_{t,0}$")
  ) +
  scale_color_manual(values = c(
    "Plug-in"  = "plum",
    "One-step" = "#0072B2",
    "TMLE"     = "#D55E00"
  )) +
  scale_linetype_manual(values = c(
    "Plug-in"  = "dotted",
    "One-step" = "solid",
    "TMLE"     = "longdash"
  )) +
  scale_shape_manual(values = c(
    "Plug-in"  = 16,
    "One-step" = 17,
    "TMLE"     = 15
  )) +
  theme(
    aspect.ratio = 1,
    axis.text.x  = element_text(angle = 45, hjust = 1),
    legend.position = "bottom"
  )

# Compute √n-standardized errors (interval grid only)
results_standardized <- results_psi %>%
  filter(grid == "interval") %>%
  left_join(parameter_df, by = "t") %>%
  mutate(
    z_Pn   = sqrt(n) * (bar_Psi_Pn   - bar_Psi_t_P0) / SD_Psi_Pn,
    z_os   = sqrt(n) * (bar_Psi_t_os - bar_Psi_t_P0) / SD_bar_D_t0
    # z_tmle = sqrt(n) * (bar_Psi_t_tmle - bar_Psi_t_P0) / SD_bar_D_t0
  ) %>%
  select(t, n, starts_with("z_")) %>%   # grid is dropped here; only "interval" rows remain
  pivot_longer(
    cols = starts_with("z_"),
    names_to = "estimator",
    values_to = "z_value"
  ) %>%
  mutate(
    estimator = recode(estimator,
                       "z_Pn"   = "Plug-in",
                       "z_os"   = "One-step",
                       "z_tmle" = "TMLE"),
    n = factor(n)  # <-- force discrete fill everywhere
  )
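
# Sketch (derived here from the data-generating process; not the code that
# builds parameter_df): since E[cos(pi * W2)] = 0 for W2 ~ Unif(-1, 1),
# bar_tau_0(V) = 0.5 * sin(2 * pi * V) with V = W1 ~ Unif(0, 1). For |t| <= 1/2,
#   bar_theta_0(t) = 1/2 + asin(2 t) / pi,
#   bar_Psi_0(t)   = t / 2 + t * asin(2 t) / pi + sqrt(1 - 4 t^2) / (2 pi).
# Clamping 2t to [-1, 1] also gives the values outside the support
# (bar_theta_0 = 0 and bar_Psi_0 = 0 for t < -1/2; bar_theta_0 = 1 and
# bar_Psi_0 = t for t > 1/2), so these helpers can cross-check bar_theta0 and
# bar_Psi_t_P0, the truth the standardized errors above are centred at.
sketch_theta0 <- function(t) {
  u <- pmin(pmax(2 * t, -1), 1)  # clamp to the support of 2 * bar_tau_0(V)
  0.5 + asin(u) / pi
}
sketch_Psi0 <- function(t) {
  u <- pmin(pmax(2 * t, -1), 1)
  t / 2 + t * asin(u) / pi + sqrt(1 - u^2) / (2 * pi)
}

# Numerical check of bar_Psi_0(t) = E[(t - bar_tau_0(V))_+] by quadrature over V
sketch_Psi0_quad <- function(t) {
  integrate(function(v) pmax(t - 0.5 * sin(2 * pi * v), 0),
            lower = 0, upper = 1, rel.tol = 1e-8, subdivisions = 1000L)$value
}
sketch_t_vals <- c(-0.6, -0.25, 0, 0.25, 0.6)
sketch_Psi0_check <- cbind(
  closed_form = sketch_Psi0(sketch_t_vals),
  quadrature  = sapply(sketch_t_vals, sketch_Psi0_quad)
)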

# Standard normal reference curve
normal_ref <- tibble(
  x = seq(-4, 4, length.out = 400),
  density = dnorm(x)
)

# Pre-filter once
df_rootn <- results_standardized %>%
  filter(
    n %in% 10000,
    t %in% c(-0.25, 0, 0.25)
  )

# Plot function: standardized-error histograms with a N(0, 1) overlay
dist_plot_for_grid_rootn <- function(dat, normal_ref) {
  dat %>%
    ggplot(aes(x = z_value, fill = n)) +
    geom_histogram(
      aes(y = after_stat(density)),
      position = "identity",
      alpha = 0.4,
      color = "black",
      bins = 40
    ) +
    geom_line(
      data = normal_ref,
      aes(x = x, y = density),
      color = "black", linewidth = 1, linetype = "solid",
      inherit.aes = FALSE
    ) +
    geom_vline(xintercept = 0, linetype = "dashed") +
    facet_grid(estimator ~ t, scales = "free") +
    theme_bw() +
    labs(
      x = latex2exp::TeX("$\\sqrt{n}(\\bar{\\psi}_{t,n} - \\bar{\\psi}_{t,0}) / SD(\\bar{D}_t)$"),
      y = "Density",
      fill = "Sample size (n)",
      title = latex2exp::TeX("Estimators of $\\bar{\\Psi}_t(P_n)$ with $N(0,1)$ Overlay")
    ) +
    scale_fill_brewer(palette = "Set2") +
    theme(
      strip.text = element_text(size = 10),
      aspect.ratio = 1,
      legend.position = "bottom"
    )
}

dist_plot_for_grid_rootn(df_rootn, normal_ref)
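
# Sketch (hypothetical helper): a numeric companion to the histograms above.
# For standardized errors that are approximately N(0, 1), the mean should be
# near 0, the SD near 1, and about 5% of values should exceed qnorm(0.975) in
# absolute value. It could be applied to z_value within each (estimator, t).
sketch_normality_summary <- function(z) {
  z <- z[is.finite(z)]
  c(
    mean      = mean(z),
    sd        = sd(z),
    tail_freq = mean(abs(z) > qnorm(0.975)),
    ks_pvalue = ks.test(z, "pnorm")$p.value  # one-sample KS test against N(0, 1)
  )
}

set.seed(2)  # note: resets the global RNG state
sketch_norm <- sketch_normality_summary(rnorm(2000))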


ggplot(parameter_df, aes(x=t, y = SD_bar_D_t0)) +
  geom_line() +
  theme_bw() +
  labs(
    x = latex2exp::TeX("$t$"),
    y = latex2exp::TeX("SD($\\bar{D}_{t,0}$)"),
    title = "Standard Deviation of the Gradient"
  ) +
  theme(aspect.ratio = 1)

# summarize remainders by t, grid and n, then scale by sqrt(n)
results_remainder_summary <- results_psi %>%
  group_by(t, grid, n) %>%
  summarise(
    # mean
    mean_bar_R1 = mean(bar_R1, na.rm = TRUE),
    mean_R_decomp_1 = mean(R_decomp_1, na.rm = TRUE),
    mean_R_decomp_2 = mean(R_decomp_2, na.rm = TRUE),
    mean_R_decomp_3 = mean(R_decomp_3, na.rm = TRUE),
    mean_R_decomp_4 = mean(R_decomp_4, na.rm = TRUE),
    mean_bar_R_total = mean(bar_R_total, na.rm = TRUE),
    
    # standard deviation
    sd_bar_R1 = sd(bar_R1, na.rm = TRUE),
    sd_R_decomp_1 = sd(R_decomp_1, na.rm = TRUE),
    sd_R_decomp_2 = sd(R_decomp_2, na.rm = TRUE),
    sd_R_decomp_3 = sd(R_decomp_3, na.rm = TRUE),
    sd_R_decomp_4 = sd(R_decomp_4, na.rm = TRUE),
    sd_bar_R_total = sd(bar_R_total, na.rm = TRUE)
  ) %>%
  # scale by √n
  mutate(
    across(starts_with("mean_"), ~ sqrt(n) * .x),
    across(starts_with("sd_"), ~ sqrt(n) * .x)
  )
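
# Sketch: why the remainders are multiplied by sqrt(n). A standard sufficient
# condition for the one-step estimator to be asymptotically linear is that the
# remainder is o_P(n^{-1/2}), i.e. sqrt(n) * remainder -> 0. For hypothetical
# deterministic remainders of order n^{-1}, n^{-1/2} and n^{-1/3}, the
# sqrt(n)-scaled values shrink to 0, stay constant, and grow, respectively.
# The remainder plots can be read against these three patterns.
sketch_n <- c(250, 500, 1000, 2500, 10000)  # sample sizes used in this simulation
sketch_scaled_remainders <- sapply(
  c(order_1 = 1, order_half = 1 / 2, order_third = 1 / 3),
  function(rate) sqrt(sketch_n) * sketch_n^(-rate)
)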

# plot √n-scaled empirical mean of remainders
results_remainder_summary %>%
  pivot_longer(
    cols = starts_with("mean_"),
    names_to = "remainder",
    values_to = "scaled_mean_value"
  ) %>%
  # tidy up labels
  mutate(
    remainder = recode(remainder,
                       "mean_bar_R1" = "bar_R1 (nuisance error)",
                       "mean_R_decomp_1" = "R1: underestimation",
                       "mean_R_decomp_2" = "R2: missed mass below t",
                       "mean_R_decomp_3" = "R3: overestimation",
                       "mean_R_decomp_4" = "R4: boundary term",
                       "mean_bar_R_total" = "Total remainder"
    )
  ) %>%
  # filter to subset of t
  filter(
    t %in% t_psi_subset
  ) %>%
  ggplot(aes(x = n, y = scaled_mean_value, color = remainder, linetype = remainder)) +
  geom_line(linewidth = 0.9) +
  facet_grid(grid ~ t, scales = "free_y") +
  geom_hline(yintercept = 0, linetype = "dashed") +
  theme_bw() +
  labs(
    x = "Sample size n (log scale)",
    y = latex2exp::TeX("$\\sqrt{n} \\times$ Empirical mean remainder"),
    color = "Remainder component",
    linetype = "Remainder component",
    title = latex2exp::TeX("Root-n Scaled Empirical Mean of Remainder Terms, Faceted by $t$")
  ) +
  scale_color_manual(values = c(
    "bar_R1 (nuisance error)" = "plum",
    "R1: underestimation" = "#0072B2",
    "R2: missed mass below t" = "#D55E00",
    "R3: overestimation" = "#009E73",
    "R4: boundary term" = "black",
    "Total remainder" = "purple"
  )) +
  scale_linetype_manual(values = c(
    "bar_R1 (nuisance error)" = "dotted",
    "R1: underestimation" = "solid",
    "R2: missed mass below t" = "solid",
    "R3: overestimation" = "solid",
    "R4: boundary term" = "dotdash",
    "Total remainder" = "longdash"
  )) +
  scale_x_log10(
    breaks = scales::trans_breaks("log10", function(x) 10^x),
    labels = scales::trans_format("log10", scales::math_format(10^.x))
  ) +
  theme(
    aspect.ratio = 1,
    axis.text.x = element_text(angle = 45, hjust = 1)
  )
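
# Sketch (hypothetical perturbations, base R only): bar_R1 is second order
# because, for each W, it multiplies a propensity-score error
# (pi_0 / pi_P - 1) by an outcome-regression error (mu_0 - mu_P). Shrinking both
# nuisance errors by a factor delta should therefore shrink the integrand by
# about delta^2. The check draws covariates from this simulation design, uses
# made-up error directions, and omits the indicator weight
# 1{bar_tau_P(V) < t} + 1{bar_tau_P(V) = t} / 2 (which lies in [0, 1]), so it
# illustrates the rate only, not bar_R1 itself.
sketch_r1_product <- function(delta, m = 10000, seed = 3) {
  set.seed(seed)  # same covariate draws for every delta (resets global RNG)
  w1 <- runif(m, 0, 1); w2 <- runif(m, -1, 1); w3 <- runif(m, -5, 5)
  # true nuisances from the simulation design
  pi0  <- plogis(0 - 0.75 * w1 + 0.5 * w2 - 0.25 * w3)
  mu00 <- -2 + 0.5 * w1 - 0.25 * w2 + 0.1 * w3
  mu01 <- mu00 + 0.5 * sin(2 * pi * w1) + 0.5 * cos(pi * w2)
  # hypothetical estimated nuisances whose errors are proportional to delta
  piP  <- plogis(qlogis(pi0) + delta * w1)
  muP0 <- mu00 + delta * w2
  muP1 <- mu01 + delta * w3 / 5
  # product part of the bar_R1 integrand (indicator weight omitted)
  term <- (pi0 / piP - 1) * (mu01 - muP1) -
    ((1 - pi0) / (1 - piP) - 1) * (mu00 - muP0)
  mean(term)
}

# Halving delta should divide the term by roughly 4
sketch_ratio <- sketch_r1_product(0.02) / sketch_r1_product(0.01)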
# plot √n-scaled empirical SD of remainders (sd_ columns of results_remainder_summary)
results_remainder_summary %>%
  pivot_longer(
    cols = starts_with("sd_"),
    names_to = "remainder",
    values_to = "scaled_sd_value"
  ) %>%
  # tidy up labels
  mutate(
    remainder = recode(remainder,
                       "sd_bar_R1" = "bar_R1 (nuisance error)",
                       "sd_R_decomp_1" = "R1: underestimation",
                       "sd_R_decomp_2" = "R2: missed mass below t",
                       "sd_R_decomp_3" = "R3: overestimation",
                       "sd_R_decomp_4" = "R4: boundary term",
                       "sd_bar_R_total" = "Total remainder"
    )
  ) %>%
  # filter to subset of t
  filter(
    t %in% t_theta_subset
  ) %>%
  ggplot(aes(x = n, y = scaled_sd_value, color = remainder, linetype = remainder)) +
  geom_line(linewidth = 0.9) +
  facet_grid(grid ~ t) +
  geom_hline(yintercept = 0, linetype = "dashed") +
  theme_bw() +
  labs(
    x = "Sample size n (log scale)",
    y = latex2exp::TeX("$\\sqrt{n} \\times$ SD(Remainder)"),
    color = "Remainder component",
    linetype = "Remainder component",
    title = latex2exp::TeX("Root-n Scaled Empirical Standard Deviation of Remainder Terms"),
    subtitle =  "Faceted by t"
  ) +
  scale_color_manual(values = c(
    "bar_R1 (nuisance error)" = "plum",
    "R1: underestimation" = "#0072B2",
    "R2: missed mass below t" = "#D55E00",
    "R3: overestimation" = "#009E73",
    "R4: boundary term" = "black",
    "Total remainder" = "purple"
  )) +
  scale_linetype_manual(values = c(
    "bar_R1 (nuisance error)" = "dotted",
    "R1: underestimation" = "solid",
    "R2: missed mass below t" = "solid",
    "R3: overestimation" = "solid",
    "R4: boundary term" = "dotdash",
    "Total remainder" = "longdash"
  )) +
  scale_x_log10(
    breaks = scales::trans_breaks("log10", function(x) 10^x),
    labels = scales::trans_format("log10", scales::math_format(10^.x))
  ) +
  theme(
    aspect.ratio = 1,
    axis.text.x = element_text(angle = 45, hjust = 1)
  )

results_theta_bias <- results_theta %>%
  select(
    t, n, grid,
    bar_theta_n, bar_theta_os, # bar_theta_tmle,
    E_barcn_bar_taun_t, f_nbar_taun, # bar_kappan,
    bar_rhon
  ) %>%
  left_join(
    parameter_df,
    by = c("t")
  ) %>%
  group_by(t, n, grid) %>%
  summarise(
    across(where(is.numeric), ~ mean(.x, na.rm = TRUE)),
    .groups = "drop"
  ) %>%
  mutate(
    bias_n = bar_theta_n - bar_theta0,
    bias_os = bar_theta_os - bar_theta0,
    # bias_tmle = bar_theta_tmle - bar_theta0,
    cbrt_n_bias_n    = n^(1/3) * bias_n,
    cbrt_n_bias_os    = n^(1/3) * bias_os,
    # cbrt_n_bias_tmle  = n^(1/3) * bias_tmle,
    # Chernoff constant parameters
    bias_E_hatc = E_barcn_bar_taun_t - E_barc0_bar_tau0_t,
    bias_f_bartau = f_nbar_taun - f_0bar_tau0_t,
    # bias_bar_kappa = bar_kappan - bar_kappa0,
    bias_bar_rho = bar_rhon - bar_rho0
  )

# tibble(
#   "W1" = seq(0, 1, by=0.001)
# ) %>%
#   mutate(
#     bar_tau = beta1 + beta5*W1,
#     # Unif(beta1, beta1+beta5)
#     f_0bar_tau0_t = 1 / (max(bar_tau) - min(bar_tau))
#   )

ggplot(parameter_df, aes(x=t, y = f_0bar_tau0_t)) +
  geom_line() +
  theme_bw() +
  labs(
    x = latex2exp::TeX("$t$"),
    y = latex2exp::TeX("$f_{\\bar{\\tau}_0}(t)$"),
    title = "Density of V-Specific CATE"
  ) +
  theme(aspect.ratio = 1)
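
# Sketch (derived here from the design, not from the code that builds
# parameter_df): bar_tau_0(V) = 0.5 * sin(2 * pi * V) with V ~ Unif(0, 1) has
# density f_0(t) = 2 / (pi * sqrt(1 - 4 t^2)) on (-1/2, 1/2), the derivative of
# bar_theta_0(t) = 1/2 + asin(2 t) / pi. It equals 2 / pi at t = 0 and is
# unbounded as |t| -> 1/2, so density-based quantities (e.g. f_nbar_taun) may be
# harder to estimate for t near the support bounds.
sketch_f0 <- function(t) {
  out <- numeric(length(t))
  inside <- abs(t) < 0.5
  out[inside] <- 2 / (pi * sqrt(1 - 4 * t[inside]^2))
  out
}

sketch_f0_mass <- integrate(sketch_f0, lower = -0.5, upper = 0.5)$value  # should be ~1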

# bias_f_bartau
results_theta_bias %>%
  filter(
    t %in% t_theta_subset
  ) %>%
  ggplot(aes(x = n)) +
  geom_line(aes(y = bias_f_bartau, color = "f_bartau", linetype = "f_bartau")) +
  facet_grid(grid ~ t) +
  theme_bw() +
  geom_hline(yintercept = 0) +
  labs(
    x = "Sample size n (log scale)",
    y = "Bias",
    color = "Estimator",
    linetype = "Estimator",
    title = latex2exp::TeX("Bias of the Estimated Density of $\\bar{\\tau}(V)$ at $t$")
  ) +
  scale_x_log10(
    breaks = scales::trans_breaks("log10", function(x) 10^x),
    labels = scales::trans_format("log10", scales::math_format(10^.x))
  ) +
  scale_color_manual(values = c("f_bartau" = "#D55E00")) +
  scale_linetype_manual(values = c("f_bartau" = "longdash")) +
  theme(
    aspect.ratio = 1,
    axis.text.x = element_text(angle = 45, hjust = 1)
  )
data.frame(
  barc0V,
  V
) %>%
  ggplot(aes(x=V, y = barc0V)) +
  geom_line() +
  theme_bw() +
  labs(
    x = latex2exp::TeX("$V=W_1$"),
    y = latex2exp::TeX("$\\bar{c}_0(V)$")
  ) +
  theme(aspect.ratio = 1)

parameter_df %>%
  ggplot(aes(x = t, y = E_barc0_bar_tau0_t)) +
  geom_line() +
  theme_bw() +
  labs(
    x = latex2exp::TeX("$t$"),
    y = latex2exp::TeX("$E[\\bar{c}_0(V) | \\bar{\\tau}_0(V) = t]$")
  ) +
  theme(aspect.ratio = 1)

# E_barc_bartau bias
results_theta_bias %>%
  filter(
    t %in% t_theta_subset
  ) %>%
  ggplot(aes(x = n)) +
  geom_line(aes(y = bias_E_hatc, color = "bias_E_hatc_bartau", linetype = "bias_E_hatc_bartau")) +
  facet_grid(grid ~ t) +
  theme_bw() +
  geom_hline(yintercept = 0) +
  labs(
    x = "n",
    y = "Bias",
    color = "Estimator",
    linetype = "Estimator",
    title = latex2exp::TeX("Bias of $E[\\bar{c}_0(V) | \\bar{\\tau}_0(V) = t]$")
  ) +
  # scale_x_log10(
  #   breaks = scales::trans_breaks("log10", function(x) 10^x),
  #   labels = scales::trans_format("log10", scales::math_format(10^.x))
  # ) +
  scale_color_manual(values = c("bias_E_hatc_bartau" = "#D55E00")) +
  scale_linetype_manual(values = c("bias_E_hatc_bartau" = "longdash")) +
  theme(
    aspect.ratio = 1,
    axis.text.x = element_text(angle = 45, hjust = 1)
  )
parameter_df %>%
  ggplot(aes(x=t, y = bar_rho0)) +
  geom_line() +
  theme_bw() +
  labs(
    x = latex2exp::TeX("$t$"),
    y = latex2exp::TeX("$\\bar{\\rho}_0(t)$")
  ) +
  theme(aspect.ratio = 1)

# bar_rho bias
results_theta_bias %>%
  filter(
    t %in% t_theta_subset
  ) %>%
  ggplot(aes(x = n)) +
  geom_line(aes(y = bias_bar_rho, color = "bar_rho_n", linetype = "bar_rho_n")) +
  facet_grid(grid ~ t) +
  theme_bw() +
  geom_hline(yintercept = 0) +
  labs(
    x = "Sample size n (log scale)",
    y = "Bias",
    color = "Estimator",
    linetype = "Estimator",
    title = latex2exp::TeX("Bias of $\\bar{\\rho}_n(t)$")
  ) +
  scale_x_log10(
    breaks = scales::trans_breaks("log10", function(x) 10^x),
    labels = scales::trans_format("log10", scales::math_format(10^.x))
  ) +
  scale_color_manual(values = c("bar_rho_n" = "#D55E00")) +
  scale_linetype_manual(values = c("bar_rho_n" = "longdash")) +
  theme(
    aspect.ratio = 1,
    axis.text.x = element_text(angle = 45, hjust = 1)
  )